#!/usr/bin/python
# encoding=UTF-8

# Copyright © 2008, 2009, 2010, 2011, 2012, 2014
#   Piotr Lewandowski <piotr.lewandowski@gmail.com>,
#   Jakub Wilk <jwilk@jwilk.net>.
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License, version 2, as
# published by the Free Software Foundation.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.

# Module metadata; __version__ is what the --version option reports.
__author__ = ('Jakub Wilk', 'Piotr Lewandowski')
__version__ = '1.5.3'

from nntplib import NNTP, NNTP_PORT, NNTPTemporaryError, NNTPError
from cStringIO import StringIO
import argparse
import email
import email.generator
import errno
import functools
import logging
import logging.handlers
import mailbox
import nntplib
import os
import os.path
import signal
import socket
import sys

import plugins
import utils

def monkey_patch_mailbox_module():
    '''
    Replace `mailbox._create_carefully` with a wrapper that restricts the
    mode of every file it creates to the read/write permission bits.

    This works around https://bugs.python.org/issue3228. Nothing is patched
    when the `mailbox` module has no `_create_carefully` attribute.
    '''
    if not hasattr(mailbox, '_create_carefully'):
        return
    wrapped = mailbox._create_carefully
    def create_carefully_666(*args, **kwargs):
        # Let mailbox create the file, then mask its current mode with 0666.
        created = wrapped(*args, **kwargs)
        path = created.name
        os.chmod(path, os.stat(path).st_mode & 0o666)
        return created
    mailbox._create_carefully = create_carefully_666
# Apply the patch once at import time; the helper is not needed afterwards.
monkey_patch_mailbox_module()
del monkey_patch_mailbox_module

# nntplib limits line length to 2048 bytes to prevent denial of service
# (CVE-2013-1752). But the line length limit doesn't buy us much, because
# sinntp loads the whole message into memory anyway. So let's lift the limit
# to 1 MiB (1 << 20 bytes).
#
# Note: _MAXLINE is a private attribute of nntplib, so this override relies on
# an implementation detail of the standard library.
#
# References:
# * https://code.google.com/p/sinntp/issues/detail?id=9
# * https://bugs.python.org/issue16040
nntplib._MAXLINE = 1 << 20

class Config(object):
    '''
    Persistent store of integer values for one NNTP server, one small file
    per key.

    The storage directory is the first of:
    * $SINNTP_HOME, if set (deprecated);
    * ~/.sinntp/, if it exists (deprecated);
    * the directory returned by utils.xdg.save_data_path('sinntp').
    '''

    def __init__(self, hostname, port):
        '''
        Remember the server address and choose the storage directory.

        A deprecation warning is logged when one of the two legacy
        locations is used.
        '''
        self._hostname = hostname
        self._port = port
        self._root = os.getenv('SINNTP_HOME')
        if self._root:
            logging.warning('$SINNTP_HOME is deprecated in favor of $XDG_DATA_HOME. See the NEWS file for details.')
            return
        self._root = os.path.expanduser('~/.sinntp/')
        if os.path.exists(self._root):
            logging.warning('$HOME/.sinntp/ is deprecated in favor of $XDG_DATA_HOME. See the NEWS file for details.')
            return
        self._root = utils.xdg.save_data_path('sinntp')

    def _get_file_name(self, name):
        '''
        Return the path of the file that stores `name`:
        "<name>@<hostname>", with "_<port>" appended for a non-default port.
        '''
        base_name = '%s@%s' % (name, self._hostname)
        if self._port != NNTP_PORT:
            base_name += '_%d' % self._port
        return os.path.join(self._root, base_name)

    def __getitem__(self, name):
        '''
        Return the integer stored under `name`, or 0 if its file does not
        exist. Other I/O errors propagate; unparsable contents raise
        ValueError.
        '''
        try:
            stream = open(self._get_file_name(name), 'rb')
        except IOError as ex:
            if ex.errno == errno.ENOENT:
                return 0
            raise
        with stream:
            return int(stream.read())

    def __setitem__(self, name, value):
        '''
        Store `value` under `name`.

        The value is written to "<path>_tmp", flushed and synced to disk,
        and only then renamed over the real file, so a reader never sees a
        partially written value.
        '''
        path = self._get_file_name(name)
        tmp_path = path + '_tmp'
        with open(tmp_path, 'wb') as stream:
            # The file is binary; str(value) of an integer is plain ASCII.
            stream.write(str(value).encode('ascii'))
            # fsync() only covers data already handed to the OS, so the
            # Python-level buffer has to be flushed first.
            stream.flush()
            os.fsync(stream.fileno())
        os.rename(tmp_path, path)

class Command(object):
    # Base class of all commands: parses the options shared by every command
    # and instantiates the requested plugins.
    #
    # There is deliberately no class docstring here: the docstring of the
    # concrete command class is used at runtime as the usage text.

    def get_option_parser(self):
        '''
        Build the argument parser with the options common to all commands.

        The usage text is the (right-stripped) docstring of the class of
        `self`. Subclasses extend the returned parser with their own options.
        '''
        o = argparse.ArgumentParser(usage=self.__doc__.rstrip())
        o.add_argument('--version', action='version', version='%(prog)s ' + __version__,  help='show program\'s version number and exit')
        o.add_argument('-s', '--syslog', dest='syslog', action='store_true', help='use syslog for logging')
        # -v and -q share one destination with three states: True (-v),
        # False (-q) and None (neither given). argparse's store_true would
        # otherwise default the attribute to False, which makes the "neither
        # given" state indistinguishable from -q.
        o.add_argument('-v', '--verbose', dest='verbose', action='store_true', default=None, help='be more verbose')
        o.add_argument('-q', '--quiet', dest='verbose', action='store_false', default=None, help='be less verbose')
        o.add_argument('-p', '--plugin', dest='plugins', action='append', metavar='PLUGIN', help='use plugin')
        o.add_argument('-S', '--server', metavar='HOST[:PORT]', dest='server', action='store', help='specify server address')
        o.add_argument('-U', '--username', dest='username', action='store', help='specify username')
        o.add_argument('-P', '--password', dest='password', action='store', help='specify password')
        o.add_argument('--no-netrc', dest='netrc', action='store_false', help='ignore credentials in ~/.netrc')
        o.add_argument('-t', '--timeout', dest='timeout', action='store', type=int, help='specify connection timeout')
        return o

    def __init__(self, argv):
        '''
        Parse `argv` (the command line without the program name).

        Sets:
        * self.options: the parsed options;
        * self.command_name: the first non-option argument;
        * self.args: the remaining arguments the parser did not recognize;
        * self.plugins: one callable per --plugin option;
        * self.oparser: the parser, kept for error reporting.

        Exits with a usage error if no command name is present.
        '''
        oparser = self.get_option_parser()
        self.options, self.args = oparser.parse_known_args(argv)
        if not self.args:
            oparser.error('A command is expected')
        self.command_name = self.args.pop(0)
        self.plugins = []
        for spec in self.options.plugins or ():
            # A plugin spec is NAME[:FIELD...]; fields containing '=' become
            # keyword arguments, the others positional arguments.
            fields = spec.split(':')
            name = fields.pop(0)
            args = [field for field in fields if '=' not in field]
            kwargs = dict(field.split('=', 1) for field in fields if '=' in field)
            self.plugins.append(functools.partial(plugins.__dict__[name], *args, **kwargs))
        self.oparser = oparser

    def __call__(self):
        # Nothing to do in the base class; concrete commands override this.
        pass

class Pull(Command):
    '''
    nntp-pull [options] groupname[>filename] [groupname[>filename] ...]
    '''

    def get_option_parser(self):
        '''
        Return the common option parser extended with --limit and --reget.
        '''
        parser = Command.get_option_parser(self)
        parser.add_argument('--limit', dest='limit', type=int, action='store', metavar='N', help='pull at most N messages')
        parser.add_argument('--reget', dest='reget', action='store_true', help='start from the first available message')
        return parser

    def __init__(self, argv):
        '''
        Parse `argv`; the arguments left after the command name are the
        group specifications and are moved from self.args to self.groups.
        '''
        Command.__init__(self, argv)
        if len(self.args) == 0:
            self.oparser.error('At least one group name is required')
        self.groups = self.args
        del self.args

    def fetch(self, connection, group_name, start):
        '''
        Generate (article number, list of article lines) pairs for the
        articles of `group_name`, beginning at article number `start` or at
        the first article of the group, whichever is greater.
        '''
        logging.info('Looking for group %r.', group_name)
        _response, count, first, last, _name = connection.group(group_name)
        count = int(count)
        first = int(first)
        last = int(last)
        number = max(start, first)
        pending = max(last - number + 1, 0)
        summary = dict(
            count = pending if pending else 'No',
            plural = '' if pending == 1 else 's'
        )
        logging.info('%(count)s message%(plural)s to download.' % summary)
        # Skip forward to the first article number that the server knows.
        found = False
        while number <= last:
            try:
                connection.stat(str(number))
            except NNTPTemporaryError:
                number += 1
                continue
            found = True
            break
        if not found:
            return
        article_no = str(number)
        while True:
            _, _, message_id, lines = connection.article(article_no)
            logging.debug('Reading message %s.', message_id)
            yield int(article_no), lines
            try:
                _, article_no, _ = connection.next()
            except NNTPTemporaryError as ex:
                # A 421 response to NEXT ends the iteration; anything else
                # is a real error.
                if not ex.response.startswith('421'):
                    raise
                return

    def __call__(self, connection):
        '''
        Download new articles of every requested group into its mbox file
        (named after the group unless "group>filename" was given) and
        record, per group, the number of the next article to fetch.
        '''
        config = Config(connection.host, connection.port)
        for group in self.groups:
            group, separator, mbox_name = group.partition('>')
            if not separator:
                mbox_name = group
            mbox = atime = mode = last = None
            start = 0 if self.options.reget else config[group]
            if self.options.limit is not None:
                _, _, _, end, _ = connection.group(group)
                start = max(start, int(end) - self.options.limit + 1)
            try:
                for no, message in self.fetch(connection, group, start):
                    if mbox is None:
                        # Open the mailbox lazily, on the first article, and
                        # remember its access time and mode for later.
                        mbox = mailbox.mbox(mbox_name, create=True)
                        mbox.lock()
                        stat = os.stat(mbox_name)
                        atime = stat.st_atime
                        mode = stat.st_mode
                    if not message or not message[-1].endswith('\n'):
                        # Makes the joined text end with a newline.
                        message.append('')
                    text = '\n'.join(message)
                    message = email.message_from_string(text)
                    for plugin in self.plugins:
                        message = plugin(message=message)
                        if message is None:
                            # The plugin dropped the message.
                            break
                    if message is not None:
                        mbox.add(message)
                    last = no
            finally:
                if mbox is not None:
                    mbox.close()
                    if atime is not None:
                        # Put back the access time recorded when the mailbox
                        # was opened, keeping the new modification time.
                        mtime = os.stat(mbox_name).st_mtime
                        os.utime(mbox_name, (atime, mtime))
                    if mode is not None:
                        os.chmod(mbox_name, mode)
                if last is not None:
                    config[group] = last + 1

class Push(Command):
    '''
    nntp-push [options] [newsgroups...]
    '''

    def __call__(self, connection):
        '''
        Read one message from standard input, run it through the plugins
        and post it over `connection`.

        The command-line arguments are used as the Newsgroups header when
        the message has none.
        '''
        message = email.message_from_file(sys.stdin)
        if 'Newsgroups' not in message and self.args:
            message['Newsgroups'] = ','.join(self.args)
        for plugin in self.plugins:
            message = plugin(message=message)
        # Serialize the message into an in-memory file for NNTP.post().
        output = StringIO()
        email.generator.Generator(output, mangle_from_=False).flatten(message)
        output.seek(0)
        connection.post(output)

class Get(Command):
    '''
    nntp-get message-id
    '''

    def __init__(self, argv):
        '''
        Parse `argv` and store the single message-id argument in
        self.message_id, wrapped in angle brackets if it is not already.
        '''
        Command.__init__(self, argv)
        if len(self.args) != 1:
            self.oparser.error('A single message-id is required')
        message_id = self.args[0]
        if '@' not in message_id:
            self.oparser.error('Message-id is malformed (\'@\' is missing)')
        if not (message_id.startswith('<') and message_id.endswith('>')):
            message_id = '<%s>' % message_id
        self.message_id = message_id
        del self.args

    def __call__(self, connection):
        '''
        Fetch the article with the stored message-id and print it.
        '''
        logging.debug('Reading message %s.', self.message_id)
        _, _, _, lines = connection.article(self.message_id)
        print('\n'.join(lines))

class List(Command):
    '''
    nntp-list
    '''

    def __init__(self, argv):
        '''
        Parse `argv`; any argument after the command name is an error.
        '''
        Command.__init__(self, argv)
        if self.args:
            self.oparser.error('This command takes no arguments')

    def __call__(self, connection):
        '''
        Print the name of every group the server lists, one per line.
        '''
        _, groups = connection.list()
        names = (name for name, _, _, _flag in groups)
        print('\n'.join(names))

class BaseCommand(Command):

    # Maps every command name to the class that implements it.
    COMMANDS = dict(
        pull = Pull,
        push = Push,
        get = Get,
        list = List,
    )

    # Usage text: the usage lines of all commands, concatenated.
    __doc__ = ''.join(cls.__doc__.rstrip() for cls in COMMANDS.values())

    def get_option_parser(self):
        # Only the options shared by all commands are known at this stage.
        return Command.get_option_parser(self)

    def __init__(self, argv):
        '''
        Parse `argv` far enough to find the command name, store the
        matching class in self.command_class and set sys.argv[0] to
        "nntp-<command>". An unknown command name is a usage error.
        '''
        Command.__init__(self, argv)
        name = self.command_name
        if name not in self.COMMANDS:
            self.oparser.error('Unknown command: %r.' % name)
        self.command_class = self.COMMANDS[name]
        sys.argv[0] = 'nntp-%s' % name

def setup_logging(syslog, verbose):
    '''
    Add a handler to the root logger and set the root logger's level.

    syslog -- if true, log through a SysLogHandler on /dev/log with the
              "news" facility, prefixing each message with sys.argv[0] and
              the process id; otherwise log through a StreamHandler.
    verbose -- True selects DEBUG, None selects INFO, False selects ERROR.
    '''
    if syslog:
        handler = logging.handlers.SysLogHandler(
            '/dev/log',
            facility = logging.handlers.SysLogHandler.LOG_NEWS
        )
        layout = ''.join((sys.argv[0], '[%(process)d]: ', '%(message)s'))
    else:
        handler = logging.StreamHandler()
        layout = '%(message)s'
    handler.setFormatter(logging.Formatter(layout, None))
    root = logging.getLogger()
    root.addHandler(handler)
    levels = {
        True:  logging.DEBUG,
        None:  logging.INFO,
        False: logging.ERROR,
    }
    root.setLevel(levels.get(verbose))

class TerminatedBySignal(Exception):
    '''
    Raised by the signal handlers that setup_signal_handler() installs;
    the exception argument is the name of the signal.
    '''

def setup_signal_handler(signame):
    '''
    Install a handler for the signal named `signame` (an attribute name of
    the `signal` module, e.g. 'SIGTERM') that raises
    TerminatedBySignal(signame) when the signal is delivered.
    '''
    signum = getattr(signal, signame)
    def raise_terminated(signo, frame):
        raise TerminatedBySignal(signame)
    signal.signal(signum, raise_terminated)

if __name__ == '__main__':
    # Exit statuses used below: 2 = no server specified, 3 = could not
    # connect, 4 = the server reported an NNTP error.
    argv = sys.argv[1:]
    # The command line is parsed twice: first only to find out which command
    # was requested, then again by that command's own parser.
    base_command = BaseCommand(argv)
    command = base_command.command_class(argv)
    setup_signal_handler('SIGTERM')
    setup_logging(command.options.syslog, command.options.verbose)
    # Server address, in order of precedence: the --server option,
    # $NNTPSERVER, the first line of /etc/news/server.
    host = command.options.server or os.getenv('NNTPSERVER')
    if not host and os.path.exists('/etc/news/server'):
        with open('/etc/news/server') as server_file:
            host = server_file.readline().strip()
    if not host:
        sys.stderr.write('NNTP server is not specified.\n')
        sys.exit(2)
    host, port = utils.split_host(host, NNTP_PORT)
    socket.setdefaulttimeout(command.options.timeout)
    try:
        logging.info('Connecting to %s:%d...', host, port)
        connection = NNTP(host,
            port=port,
            user=command.options.username,
            password=command.options.password,
            readermode=True,
            usenetrc=command.options.netrc
        )
    except socket.error as e:
        # Not every socket error carries an (errno, message) pair:
        # socket.timeout, for one, has a single argument, so indexing the
        # exception with [1] would itself fail. strerror is None in that
        # case and the exception's own text is used instead.
        logging.error('Could not connect to %s:%d: %s', host, port, e.strerror or e)
        sys.exit(3)
    except NNTPError as e:
        logging.error('%s:%d returned an error: %s', host, port, e)
        sys.exit(4)
    try:
        command(connection)
    except NNTPError as e:
        logging.error('NNTP error: %s', e)
        sys.exit(4)
    connection.quit()
    logging.info('Connection to %s:%d closed.', host, port)

# vim:ts=4 sw=4 et
